import pandas as pd
import numpy as np
import datetime
from src.configs.assets import UNIFIED_LABELS
from sklearn.preprocessing import OneHotEncoder
np.random.seed(510)

# Indicator columns that `dummy` one-hot encodes. The rename keys below
# assume each one has a level that pandas labels "1.0".
DUMMY_VARIABLES = ['Is_Married', 'Sex_Female', 'Is_Unemployed']

# Numeric columns for which `dummy` adds a "Missing_<name>" flag.
CONTINUOUS_VARIABLES = [
    'Age',
    'Household_Size',
    'Household_Income',
    'Personal_Income',
]

# Translates the column labels produced by pd.get_dummies into final names:
#   "<var>_1.0" -> "<var>"            (indicator for the level 1.0)
#   "<var>_nan" -> "Missing_<var>"    (indicator for a missing value)
RENAME_MAP = {
    raw_label: final_label
    for variable in DUMMY_VARIABLES
    for raw_label, final_label in (
        (f"{variable}_1.0", variable),
        (f"{variable}_nan", f"Missing_{variable}"),
    )
}
def dummy(data):
    """Encode the indicator variables and flag missing continuous values.

    Steps, in order:
      1. One-hot encode the DUMMY_VARIABLES columns with an extra NaN
         level, dropping the first level of each variable.
      2. Build a boolean "Missing_<name>" column for every entry of
         CONTINUOUS_VARIABLES (True where the value is NaN).
      3. Replace the original DUMMY_VARIABLES columns with the encoded
         ones, append the missing flags, and rename via RENAME_MAP.

    The continuous columns themselves are left untouched; NaNs in them
    are not filled here.

    Args:
        data: DataFrame containing every column named in DUMMY_VARIABLES
            and CONTINUOUS_VARIABLES.

    Returns:
        A new DataFrame; ``data`` is not modified.
    """
    indicator_frame = pd.get_dummies(
        data[DUMMY_VARIABLES],
        columns=DUMMY_VARIABLES,
        dummy_na=True,
        drop_first=True,
    )
    missing_flags = data[CONTINUOUS_VARIABLES].isna().add_prefix("Missing_")

    without_raw_indicators = data.drop(DUMMY_VARIABLES, axis=1)
    combined = without_raw_indicators.join(indicator_frame).join(missing_flags)
    return combined.rename(columns=RENAME_MAP)

# Raw GESIS column name -> unified column name.
# NOTE(review): not referenced in this part of the file; presumably used
# by the ETL step that produces gesis.csv -- confirm before removing.
gesis_rename_map = {
    "isNonResponse": "Nonresponse_This_Wave",
    "is_married": "Is_Married",
    "Age": "Age",
    "Invited_Waves": "Invited_Waves",
    "Household_Size": "Household_Size",
    "Household_Income": "Household_Income",
    "Personal_Income": "Personal_Income",
    "Sex_Female": "Sex_Female",
    "Employment_not_employed": "Is_Unemployed",
    "Employment_nan": "Missing_Employment_Status",
    "NonResponse_in_Next1Wave": "NonResponse_Next_Wave",
    "Historic_Nonresponse_Rate": "Historic_Nonresponse_Rate",
    "Wave_Date": "Wave_Date",
}

# Raw SOEP column name -> unified column name.
# NOTE(review): not referenced in this part of the file; presumably used
# by the ETL step that produces data_soep.csv -- confirm before removing.
soep_rename_map = {
    "is_nonresponse": "Nonresponse_This_Wave",
    "is_married": "Is_Married",
    "age": "Age",
    "Invited_Waves": "Invited_Waves",
    "hhsize": "Household_Size",
    "hhinc": "Household_Income",
    "pinc": "Personal_Income",
    "gender_Female": "Sex_Female",
    "employment_[0] Not Employed": "Is_Unemployed",
    "employment_nan": "Missing_Employment_Status",
    "Nonresponse_Next_Wave": "NonResponse_Next_Wave",
    "Historic_Nonresponse_Rate": "Historic_Nonresponse_Rate",
    "Wave_Date": "Wave_Date",
}

def add_historic_nonresponse_rate(data, id_label, wave_label, nonresponse_label):
    """Attach each participant's running nonresponse rate to ``data``.

    For every wave ``w`` the mean of ``nonresponse_label`` is computed per
    participant over all rows whose wave is ``<= w`` (the current wave
    included). The value for ``(w, participant)`` is then joined onto the
    matching row of ``data`` as ``Historic_Nonresponse_Rate``.

    Args:
        data: DataFrame whose index has the wave as level 0 and contains
            levels named ``wave_label`` and ``id_label``. The index should
            be sorted, because waves are selected with a label slice
            (``.loc[:wave]``).
        id_label: Name of the participant index level.
        wave_label: Name of the wave index level.
        nonresponse_label: Column holding the per-wave nonresponse flag.

    Returns:
        A new DataFrame equal to ``data`` plus the
        ``Historic_Nonresponse_Rate`` column; ``data`` itself is not
        modified. For input without any waves the column is added with no
        values instead of raising.
    """
    waves = data.index.get_level_values(0).unique().sort_values()

    # Collect one frame per wave and concatenate once at the end. Growing a
    # frame with pd.concat inside the loop copies all earlier rows on every
    # iteration, and seeding it with an empty DataFrame depends on concat
    # discarding empty inputs to keep the id index name.
    per_wave = []
    for wave in waves:
        rates = data.loc[:wave, [nonresponse_label]].groupby(id_label).mean()
        rates[wave_label] = wave
        per_wave.append(rates)

    if not per_wave:
        # Nothing to aggregate: keep the output schema consistent.
        return data.assign(Historic_Nonresponse_Rate=float("nan"))

    history = pd.concat(per_wave, axis=0)
    history = history.reset_index().set_index([wave_label, id_label])
    history.columns = ['Historic_Nonresponse_Rate']

    # Left join: (wave, id) pairs that exist only in `history` (participant
    # seen earlier but absent from this wave) are dropped.
    return data.join(history)

##overload while trialling new ETL pipeline
def get_gesis_data():
    """Load the GESIS panel and return it in encoded form.

    Reads ``./data/sensitive/GESIS/gesis.csv``, converts ``Wave_Date`` to
    ``datetime.date`` values, casts ``Sex_Female`` and
    ``Missing_Employment_Status`` to int, indexes by
    ``(Wave, Participant ID)``, sorts the index, keeps only the
    UNIFIED_LABELS columns, applies ``dummy`` and fills remaining NaNs
    with 0.

    Returns:
        The encoded DataFrame.
    """
    frame = pd.read_csv('./data/sensitive/GESIS/gesis.csv')

    # Dates are stored as ISO strings (YYYY-MM-DD); keep only the date part.
    parsed_dates = pd.to_datetime(frame['Wave_Date'], format="%Y-%m-%d")
    frame['Wave_Date'] = parsed_dates.dt.date

    for flag_column in ('Sex_Female', 'Missing_Employment_Status'):
        frame[flag_column] = frame[flag_column].astype(int)

    indexed = frame.set_index(['Wave', 'Participant ID']).sort_index()
    selected = indexed[UNIFIED_LABELS]
    return dummy(selected).fillna(0)

def get_soep_data():
    """Load the SOEP panel and return it in encoded form.

    Reads ``data/sensitive/SOEP/data_soep.csv`` indexed by
    ``(Wave, Participant ID)``, converts ``Wave_Date`` to
    ``datetime.date`` values, applies ``dummy`` and fills remaining NaNs
    with 0.

    Returns:
        The encoded DataFrame.
    """
    # Forward slashes keep the path portable and consistent with the other
    # loaders; a backslash-separated literal only resolves on Windows and
    # contains invalid escape sequences unless written as a raw string.
    # NOTE(review): date_format only applies to columns parsed through
    # parse_dates, which is not requested here, so it appears to have no
    # effect -- confirm before removing.
    data = pd.read_csv(
        'data/sensitive/SOEP/data_soep.csv',
        date_format="%Y/%m/%d",
        index_col=['Wave', 'Participant ID'],
    )
    # Dates are stored as ISO strings (YYYY-MM-DD); keep only the date part.
    data['Wave_Date'] = pd.to_datetime(data['Wave_Date'], format="%Y-%m-%d").dt.date
    return dummy(data).fillna(0)


def get_freda_data():
    """Load the FREDA panel and return it in encoded form.

    Reads ``data/sensitive/FREDA/FREDA_data.csv``, indexes it by
    ``(Wave, Participant ID)``, converts ``Wave_Date`` to
    ``datetime.date`` values, applies ``dummy`` and fills remaining NaNs
    with 0.

    Returns:
        The encoded DataFrame.
    """
    raw = pd.read_csv('data/sensitive/FREDA/FREDA_data.csv', date_format="%Y/%m/%d")
    panel = raw.set_index(['Wave', 'Participant ID'])

    wave_dates = pd.to_datetime(panel['Wave_Date'], format="%Y-%m-%d")
    panel['Wave_Date'] = wave_dates.dt.date

    # NOTE(review): an earlier comment here said the last wave is invalid,
    # but no wave is dropped in this function -- confirm whether it should be.
    encoded = dummy(panel)
    return encoded.fillna(0)

## Special case: MCS requires demographic data from wave 46 of GIP,
## but the GIP experiment requires that we remove non-starter cohorts,
## so we have two functions to get the two versions of the GIP data.
def _get_full_gip_data():
    """Load the unfiltered GIP panel.

    Reads ``data/sensitive/GIP/gip.csv`` indexed by
    ``(Wave, Participant ID)`` and converts ``Wave_Date`` to
    ``datetime.date`` values. No rows are removed and no encoding is
    applied.

    Returns:
        The raw GIP DataFrame.
    """
    panel = pd.read_csv(
        'data/sensitive/GIP/gip.csv',
        index_col=['Wave', 'Participant ID'],
    )
    wave_dates = pd.to_datetime(panel['Wave_Date'], format="%Y-%m-%d")
    panel['Wave_Date'] = wave_dates.dt.date
    return panel

def get_gip_data():
    """Load the GIP panel restricted to the starting cohort, encoded.

    Starting from the full GIP data this:
      1. fills missing ``NonResponse_Next_Wave`` values with 1,
      2. drops every row of the highest-numbered wave,
      3. keeps only participants who appear in the lowest remaining wave,
      4. applies ``dummy`` and fills remaining NaNs with 0.

    Returns:
        The encoded DataFrame.
    """
    panel = _get_full_gip_data()

    # A missing "next wave" outcome is treated as nonresponse.
    panel['NonResponse_Next_Wave'] = panel['NonResponse_Next_Wave'].fillna(1)

    # Drop the final wave.
    waves = panel.index.get_level_values(0)
    panel = panel.loc[waves != waves.max(), :]

    # Restrict to participants present in the first remaining wave.
    waves = panel.index.get_level_values(0)
    participants = panel.index.get_level_values(1)
    starters = participants[waves == waves.min()]
    panel = panel.loc[participants.isin(starters), :]

    return dummy(panel).fillna(0)

def get_mcs_data():
    """Load the MCS panel and return it in encoded form.

    Reads ``data/sensitive/MCS/mcs.csv`` indexed by
    ``(Wave, Participant ID)``, converts ``Wave_Date`` to
    ``datetime.date`` values, applies ``dummy`` and fills remaining NaNs
    with 0.

    Returns:
        The encoded DataFrame.
    """
    panel = pd.read_csv(
        'data/sensitive/MCS/mcs.csv',
        index_col=['Wave', 'Participant ID'],
    )
    wave_dates = pd.to_datetime(panel['Wave_Date'], format="%Y-%m-%d")
    panel['Wave_Date'] = wave_dates.dt.date
    return dummy(panel).fillna(0)